import numpy as np
import matplotlib.pyplot as plt

"""
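plot_images: draw a grid of images, each titled with its class name.

parameters:
- images: sequence of images to display, one per grid cell
- labels: sequence of labels, one per image; a vector label (e.g. one-hot)
  is mapped to its class index with argmax
- class_names: sequence mapping a class index to its display name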
"""

# Display names for the label classes; position in the list is the class
# index. NOTE(review): presumably passed as the ``class_names`` argument of
# ``plot_images`` — confirm at the call site.
class_names = [
    'buildings',
    'forest',
    'glacier',
    'mountain',
    'sea',
    'street',
]


def plot_images(images, labels, class_names, *,
                nrows=3, ncols=5, figsize=(12, 6)):
    """Show images in a grid, each titled with the name of its class.

    Parameters
    ----------
    images : sequence of array-like
        Images to draw; each one is passed unchanged to ``Axes.imshow``.
    labels : sequence
        One label per image. A vector label (e.g. one-hot or per-class
        scores) is mapped to its class index with ``np.argmax``; a scalar
        label is used as the class index directly.
    class_names : sequence of str
        Display name for each class index.
    nrows, ncols : int, keyword-only
        Grid shape (default 3 x 5). Only the first ``nrows * ncols``
        image/label pairs are drawn; any extra ones are ignored.
    figsize : tuple of float, keyword-only
        Figure size forwarded to ``plt.subplots`` (default ``(12, 6)``).
    """
    # squeeze=False always yields a 2-D array of Axes, so ravel() also works
    # for a 1x1 grid, where plt.subplots would otherwise return a bare Axes.
    _, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    axes = axes.ravel()
    for img, label, ax in zip(images, labels, axes):
        ax.imshow(img)
        label = np.asarray(label)
        # np.argmax of a scalar is always 0, which would title every image
        # with class_names[0]; integer labels are therefore used as-is.
        class_idx = int(label) if label.ndim == 0 else int(np.argmax(label))
        ax.set_title(class_names[class_idx])
    # Hide ticks/frames on every cell, including the unused trailing cells
    # left over when there are fewer images than grid positions.
    for ax in axes:
        ax.axis('off')
    plt.tight_layout()
    plt.show()
